/*
 Copyright 2013--Present JMM_PROGNAME

 This file is distributed under the terms of the JMM_PROGNAME License.

 You should have received a copy of the JMM_PROGNAME License.
 If not, see <JMM_PROGNAME WEBSITE>.
*/
// CREATED    : 11/8/2013
// LAST UPDATE: 12/30/2014


// STL
#include <algorithm>           // std::min_element
#include <iostream>

// jSCIENCE
#include "jutility.hpp"        // get_n_rand_unique_elements(), remove_multiple_elements()

// stats++
#include "statsxx/dataset.hpp" // DataSet


//========================================================================
//========================================================================
//
// NAME: void balance_training_set(DataSet &TS)
//
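// DESC: Balances a classification training set so that each output class appears in (approximately) equal proportion, by removing randomly selected points (via get_n_rand_unique_elements()) from the over-represented classes. For classification training sets ONLY.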
//
//========================================================================
//========================================================================
inline void balance_training_set(DataSet &TS)
{
    // Binomial (nout == 1): the 1.0 and 0.0 classes are trimmed to the same size.
    // Multinomial (nout > 1): every class is trimmed to the size of the smallest class. A point
    // belongs to the FIRST output equal to 1.0. Points with no 1.0 output are never removed.

    // nothing to balance (and TS.pt[0] below would be out of range)
    if( TS.pt.empty() )
    {
        return;
    }

    using index_t = decltype(TS.pt.size());

    const int nout = static_cast<int>(TS.pt[0].out.size());

    // no outputs -> no classes (and min_element below would dereference an empty range)
    if( nout <= 0 )
    {
        return;
    }


    // GET INDICES FOR ALL CLASSIFICATION (1.0) OUTPUTS ...
    std::vector<std::vector<index_t>> classification_indices(nout);

    // note: only filled for binomial classifications, which store only the 1.0 class in output 0
    std::vector<index_t> unclassification_indices;

    for( index_t i = 0; i < TS.pt.size(); ++i )
    {
        for( int j = 0; j < nout; ++j )
        {
            if( TS.pt[i].out[j] == 1.0 )
            {
                classification_indices[j].push_back(i);
                break;
            }
        }

        if( (nout == 1) && (TS.pt[i].out[0] == 0.0) )
        {
            unclassification_indices.push_back(i);
        }
    }


    // GET ALL INDICES TO DELETE TO BALANCE TYPES OF OUTPUT ...
    std::vector<decltype(classification_indices.size())> to_del;

    // note: we have to handle binomial classifications separately, since only the 0 position is stored
    if( nout == 1 )
    {
        // Signed difference of the ACTUAL class counts (# of 1.0 minus # of 0.0):
        // - doing this in size_t arithmetic would wrap around when the 0.0 class is larger;
        // - using TS.pt.size() - (# of 1.0) as the 0.0 count would over-count it if any output is
        //   neither exactly 0.0 nor 1.0, requesting more indices than the 0.0 pool contains.
        const int n = static_cast<int>(classification_indices[0].size())
                    - static_cast<int>(unclassification_indices.size());

        // MORE 1.0 CLASSIFICATIONS ...
        if( n > 0 )
        {
            to_del = get_n_rand_unique_elements( n, classification_indices[0] );
        }
        // ... MORE 0.0 CLASSIFICATIONS ...
        else if( n < 0 )
        {
            to_del = get_n_rand_unique_elements( -n, unclassification_indices );
        }
    }
    else
    {
        std::vector<int> sizes(nout);
        for( int i = 0; i < nout; ++i )
        {
            sizes[i] = static_cast<int>(classification_indices[i].size());
        }

        // one search gives both the smallest class and its size
        const auto min_it       = std::min_element(sizes.begin(), sizes.end());
        const int  min_set      = static_cast<int>(min_it - sizes.begin());
        const int  min_set_size = *min_it;


        for( int i = 0; i < nout; ++i )
        {
            if( i == min_set )
            {
                continue;
            }

            // remove the excess of class i over the smallest class
            std::vector<decltype(classification_indices.size())> tmp = get_n_rand_unique_elements( (sizes[i] - min_set_size), classification_indices[i]);
            to_del.insert( to_del.end(), tmp.begin(), tmp.end() );
        }
    }


    remove_multiple_elements(to_del, TS.pt);
}
